{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "adding_metadata_to_tflite_model_colab_notebook.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ],
      "metadata": {
        "id": "U2F9EQ7PlPgW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "metadata": {
        "id": "-UoKBJEqlfR7"
      },
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Tensorflow Lite Gesture Classification Example \bAdding Metadata Script\n",
        "\n",
        "\n",
        "This guide shows how you can go about adding the metadata into Tensorflow Lite model.\n",
        "\n",
        "Run all steps in-order. At the end, `model_metadata.tflite` file will be downloaded."
      ],
      "metadata": {
        "id": "Hh7ruhMFlr_V"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jVQAmH1w1wu5"
      },
      "outputs": [],
      "source": [
        "# Install the TensorFlow Lite Support Pypi package.\n",
        "!pip install tflite-support-nightly"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Import the required packages.\n",
        "from tflite_support.metadata_writers import image_classifier\n",
        "from tflite_support.metadata_writers import metadata_info\n",
        "from tflite_support.metadata_writers import writer_utils\n",
        "from tflite_support import metadata_schema_py_generated as _metadata_fb\n",
        "\n",
        "from google.colab import files"
      ],
      "metadata": {
        "id": "Qmx4P7Au12oH"
      },
      "execution_count": 21,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Upload your model and labels here.\n",
        "files.upload()"
      ],
      "metadata": {
        "id": "KUdyCjopkEO0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Declare the resource paths\n",
        "_MODEL_PATH = \"model.tflite\"\n",
        "_LABEL_PATH = \"labels.txt\"\n",
        "_SAVE_TO_PATH = \"model_metadata.tflite\""
      ],
      "metadata": {
        "id": "1rXRefHHIMzK"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Format the inline label file to 8 lines label file.\n",
        "search_text = \",\"\n",
        "replace_text = \"\\n\"\n",
        "  \n",
        "# Opening our label file in read only\n",
        "with open(_LABEL_PATH, 'r') as file:\n",
        "  \n",
        "    # Reading the content of the file\n",
        "    data = file.read()\n",
        "  \n",
        "    # Searching and replacing the text\n",
        "    data = data.replace(search_text, replace_text)\n",
        "  \n",
        "# Opening our label file in write only\n",
        "with open(_LABEL_PATH, 'w') as file:\n",
        "  \n",
        "    # Writing the replaced data in our label file\n",
        "    file.write(data)"
      ],
      "metadata": {
        "id": "Za0zcM97HMK_"
      },
      "execution_count": 27,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model_buffer = writer_utils.load_file(_MODEL_PATH)\n",
        "\n",
        "# Create general model information.\n",
        "general_md = metadata_info.GeneralMd(\n",
        "    name=\"Guesture Classifier\",\n",
        "    version=\"v1\",\n",
        "    description=(\"Identify the most prominent gesture in the image from a \"\n",
        "                 \"known set of categories.\"),\n",
        "    author=\"TensorFlow Lite\",\n",
        "    licenses=\"Apache License. Version 2.0\")\n",
        "\n",
        "# Create input tensor information.\n",
        "input_md = metadata_info.InputImageTensorMd(\n",
        "    name=\"input image\",\n",
        "    description=(\"Input image to be classified. The expected image is \"\n",
        "                 \"224 x 224, with three channels (red, blue, and green) per \"\n",
        "                 \"pixel. Each element in the tensor is a value between min and \"\n",
        "                 \"max, where (per-channel) min is [-1] and max is [1].\"),\n",
        "    norm_mean=[127.5],\n",
        "    norm_std=[127.5],\n",
        "    color_space_type=_metadata_fb.ColorSpaceType.RGB,\n",
        "    tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])\n",
        "\n",
        "# Create output tensor information.\n",
        "output_md = metadata_info.ClassificationTensorMd(\n",
        "    name=\"probability\",\n",
        "    description=\"Probabilities of the 8 labels respectively.\",\n",
        "    label_files=[\n",
        "        metadata_info.LabelFileMd(file_path=_LABEL_PATH, locale=\"en\")\n",
        "    ],\n",
        "    tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])\n",
        "\n",
        "ImageClassifierWriter = image_classifier.MetadataWriter\n",
        "# Create the metadata writer.\n",
        "writer = ImageClassifierWriter.create_from_metadata_info(\n",
        "    model_buffer, general_md, input_md, output_md)\n",
        "\n",
        "# Verify the metadata generated by metadata writer.\n",
        "print(writer.get_metadata_json())\n",
        "\n",
        "# Populate the metadata into the model.\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ],
      "metadata": {
        "id": "2gfWkoDF14mA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the model after add metadata.\n",
        "files.download(_SAVE_TO_PATH)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "OQmXw6MBlhla",
        "outputId": "6b1ea079-ee07-4c90-e152-add1844347d5"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_0b7c0687-5a92-4369-9646-fc0b5910cc8a\", \"model_metadata.tflite\", 5879897)"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}